

# ################################################
# (3) Hybrid likelihood batch prediction problem 
# ################################################

# Per-observation probability weights of the form
#   p_i = 1 / (n + n * lambda * m_i),
# where n is the number of rows of `dat`.
#
# Arguments:
#   m      numeric vector (or one-column matrix) of per-row values m_i
#   lambda scalar multiplier (see get_lambda)
#   dat    data set; only its row count is used
# Returns an object shaped like `m` holding the weights p_i.
get_pi <- function(m, lambda, dat) {
  n_obs <- nrow(dat)
  1 / (n_obs + n_obs * lambda * m)
}

# Solve for the multiplier lambda that satisfies
#   sum(m / (1 + lambda * m)) = 0.
#
# The root is searched on [ (1/n - 1)/max(m), (1/n - 1)/min(m) ], i.e. the
# range on which every 1 + lambda * m_i stays >= 1/n (n = nrow(dat)).
#
# Arguments:
#   m    numeric vector (or one-column matrix) of per-row values
#   dat  data set; only its row count is used
# Returns the scalar root found by uniroot() (default uniroot tolerance).
#
# Special cases:
#   * all m equal to zero: the equation holds for every lambda, so 0 is
#     returned (this gives uniform weights 1/n in get_pi).
#   * all m >= 0 or all m <= 0 (not all zero): no finite lambda solves the
#     equation, and uniroot() cannot search an infinite interval, so an
#     explicit error is raised.
#   * empty m or m containing NA/NaN: explicit error.
get_lambda <- function(m, dat) {
  score <- function(lambda) sum(m / (1 + lambda * m))

  if (length(m) == 0L || anyNA(m)) {
    stop("get_lambda: 'm' must be non-empty and contain no NA/NaN",
         call. = FALSE)
  }

  n <- nrow(dat)
  min_m <- min(m)
  max_m <- max(m)

  # Degenerate but harmless: every term of the score is zero for any lambda.
  if (min_m == 0 && max_m == 0) {
    return(0)
  }

  # Zero must lie strictly between min(m) and max(m) for a finite root.
  if (min_m >= 0 || max_m <= 0) {
    stop("get_lambda: all values of 'm' are ",
         if (min_m >= 0) "non-negative" else "non-positive",
         "; sum(m / (1 + lambda * m)) = 0 has no finite solution",
         call. = FALSE)
  }

  lb <- (1 / n - 1) / max_m
  ub <- (1 / n - 1) / min_m
  uniroot(score, c(lb, ub))$root
}

# Per-row contrast
#   m = (y_a1m0 - y_a0m0) * P(M = 0 | a = 0) + (y_a1m1 - y_a0m1) * P(M = 1 | a = 0)
# built from the linear outcome predictor and the logistic mediator model.
#
# Arguments:
#   beta  coefficient vector: the first ncol(Xm) entries are the mediator-model
#         coefficients, the remainder the outcome-model coefficients
#   dat   data set passed to process_data()
#   Xm    mediator-model design matrix (only its column names/count are used)
#   Xy    outcome-model design matrix (only its column names are used)
# Returns a one-column matrix with one value per row.
#
# process_data() is defined outside this section; presumably it returns the
# data with treatment `a` and mediator `m` fixed at the given values -- confirm.
# NOTE(review): idx_m / idx_y are positions within colnames(Xm) / colnames(Xy)
# but are used to select columns of the process_data() output; this relies on
# the column order of that output lining up -- confirm against process_data().
get_m <- function(beta, dat, Xm, Xy) {
  n_coef_m <- ncol(Xm)
  beta_m <- beta[seq_len(n_coef_m)]
  beta_y <- beta[(n_coef_m + 1):length(beta)]

  # Data under each (a, m) combination.
  dat_a0m0 <- process_data(dat, a = 0, m = 0)
  dat_a0m1 <- process_data(dat, a = 0, m = 1)
  dat_a1m0 <- process_data(dat, a = 1, m = 0)
  dat_a1m1 <- process_data(dat, a = 1, m = 1)

  # Mediator probabilities under a = 0 (inverse-logit of the linear predictor).
  idx_m <- which(colnames(Xm) %in% colnames(dat))
  p_m1a0 <- 1 / (1 + exp(-dat_a0m1[, idx_m] %*% beta_m))
  p_m0a0 <- 1 - p_m1a0

  # Linear outcome predictions under each (a, m) combination.
  idx_y <- which(colnames(Xy) %in% colnames(dat))
  y_a0m0 <- dat_a0m0[, idx_y] %*% beta_y
  y_a0m1 <- dat_a0m1[, idx_y] %*% beta_y
  y_a1m0 <- dat_a1m0[, idx_y] %*% beta_y
  y_a1m1 <- dat_a1m1[, idx_y] %*% beta_y

  (y_a1m0 - y_a0m0) * p_m0a0 + (y_a1m1 - y_a0m1) * p_m1a0
}


# Negative hybrid log-likelihood
#   -sum( log(p_X) + log(p_Y) + log(p_M) )
# where
#   p_X  weights from get_pi() evaluated at the lambda solving get_lambda(),
#   p_Y  normal density (sd = 1) of Y around the linear predictor Xy %*% beta_y,
#   p_M  Bernoulli probability of the observed M under a logistic model in Xm.
#
# Arguments:
#   beta      coefficient vector: first ncol(Xm) entries for the mediator
#             model, the rest for the outcome model
#   dat       data set with columns Y and M
#   idx_test  row indices whose Y is replaced by its own prediction
#   Xm, Xy    design matrices of the mediator and outcome models
# Returns a single number (the negative log-likelihood).
get_negloglik <- function(beta, dat, idx_test, Xm, Xy) {
  m <- get_m(beta, dat, Xm, Xy)
  lambda <- get_lambda(m, dat)

  n_coef_m <- ncol(Xm)
  beta_m <- beta[seq_len(n_coef_m)]
  beta_y <- beta[(n_coef_m + 1):length(beta)]

  # The mediator must come from `dat`, not from a variable that happens to
  # exist in an enclosing or global environment.
  M <- dat$M
  if (is.null(M)) {
    stop("get_negloglik: 'dat' must contain a column 'M'", call. = FALSE)
  }

  # Rows in idx_test get Y set to the fitted value, so their residual is zero
  # and they contribute only the constant dnorm(0) to the outcome term.
  Y_hat <- Xy %*% beta_y
  Y <- dat$Y
  Y[idx_test] <- Y_hat[idx_test]

  p_X <- get_pi(m, lambda, dat)
  p_Y <- dnorm(Y, Y_hat, 1)

  p_M1 <- 1 / (1 + exp(-Xm %*% beta_m))
  p_M <- M * p_M1 + (1 - M) * (1 - p_M1)

  # Closed form, dropping additive constants:
  # sum( - log(1 + lambda*m)
  #      - (Y - Y_hat)^2/2
  #      - dat$M*log(1 + exp(-Xm%*%beta_m)) - (1-dat$M)*log(1 + exp(Xm%*%beta_m)))
  -sum(log(p_X) + log(p_Y) + log(p_M))
}

# Numerical gradient of get_negloglik() with respect to beta, using central
# differences: (f(beta + delta * e_i) - f(beta - delta * e_i)) / (2 * delta).
#
# Arguments:
#   beta                     coefficient vector at which to differentiate
#   dat, idx_test, Xm, Xy    passed through unchanged to get_negloglik()
#   delta                    half-width of the finite-difference step
# Returns an unnamed numeric vector of length(beta); numeric(0) for empty beta.
get_gradient_negloglik <- function(beta, dat, idx_test, Xm, Xy, delta) {
  grad <- numeric(length(beta))
  for (i in seq_along(beta)) {
    beta_lo <- beta
    beta_hi <- beta
    beta_lo[i] <- beta[i] - delta
    beta_hi[i] <- beta[i] + delta
    f_hi <- get_negloglik(beta_hi, dat, idx_test, Xm, Xy)
    f_lo <- get_negloglik(beta_lo, dat, idx_test, Xm, Xy)
    grad[i] <- (f_hi - f_lo) / (2 * delta)
  }
  grad
}

# Minimise the negative hybrid log-likelihood by gradient descent with a
# decaying step size alpha / i and a finite-difference gradient.
#
# Arguments:
#   beta      starting coefficients: mediator-model coefficients first (one per
#             column of the fmla_m design matrix), then outcome-model
#             coefficients (one per column of the fmla_y design matrix)
#   dat       data set
#   idx_test  row indices forwarded to get_negloglik()
#   fmla_m    formula of the mediator model
#   fmla_y    formula of the outcome model
#   opt       list with elements alpha (base step size), threshold
#             (convergence tolerance), max_iter (iteration cap) and delta
#             (finite-difference step)
# Returns a list with beta_m, beta_y, m, lambda, px (weights from get_pi),
# Y_hat, mle (log-likelihood at the final beta) and nde (sum(m * px)).
#
# Iteration stops when either every coordinate of the proposed step or every
# gradient component is below `threshold` in absolute value; in that case the
# proposed step is not applied. With max_iter <= 0 no iteration is run and the
# results are evaluated at the starting beta.
optimize_hybrid <- function(beta, dat, idx_test, fmla_m, fmla_y, opt) {
  alpha <- opt$alpha
  threshold <- opt$threshold
  max_iter <- opt$max_iter
  delta <- opt$delta

  # na.action = NULL keeps rows with missing values in the design matrices.
  Xm <- as.matrix(model.matrix(fmla_m, data = model.frame(dat, na.action = NULL)))
  Xy <- as.matrix(model.matrix(fmla_y, data = model.frame(dat, na.action = NULL)))

  n_coef_m <- ncol(Xm)
  if (length(beta) != n_coef_m + ncol(Xy)) {
    stop("optimize_hybrid: length(beta) is ", length(beta),
         " but the design matrices have ", n_coef_m + ncol(Xy), " columns",
         call. = FALSE)
  }
  names(beta) <- c(colnames(Xm), colnames(Xy))

  for (i in seq_len(max_iter)) {
    lr <- alpha / i
    grad <- get_gradient_negloglik(beta, dat, idx_test, Xm, Xy, delta)
    beta_next <- beta - lr * grad

    if (all(abs(beta_next - beta) < threshold) || all(abs(grad) < threshold)) break

    # Progress diagnostic: nde at the proposed coefficients.
    m_step <- get_m(beta_next, dat, Xm, Xy)
    lambda_step <- get_lambda(m_step, dat)
    px_step <- get_pi(m_step, lambda_step, dat)
    nde <- sum(m_step * px_step)
    cat("nde = ", nde, "\n")
    # if(nde < 0.0001) break

    beta <- beta_next
    if (i == max_iter) cat("reached maximum number of iterations without convergence \n")
  }

  beta_m <- beta[seq_len(n_coef_m)]
  beta_y <- beta[(n_coef_m + 1):length(beta)]
  names(beta_m) <- colnames(Xm)
  names(beta_y) <- colnames(Xy)

  # Quantities at the final coefficients.
  m <- get_m(beta, dat, Xm, Xy)
  lambda <- get_lambda(m, dat)
  px <- get_pi(m, lambda, dat)
  nde <- sum(m * px)

  Y_hat <- Xy %*% beta_y

  neg_log_lik <- get_negloglik(beta, dat, idx_test, Xm, Xy)
  log_lik <- -neg_log_lik # - n*log(sqrt(2*3.1416))

  list(beta_m = beta_m,
       beta_y = beta_y,
       m = m,
       lambda = lambda,
       px = px,
       Y_hat = Y_hat,
       mle = log_lik,
       nde = nde)
}
 

 
